/*
    Copyright 2007-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Declaration of a grid class used for making images of dust emission.

#ifndef __ir_grid__
#define __ir_grid__

#include "grid.h"
#include "optical.h"
#include "emission.h"
#include "blitz-fits.h"
#include "mcrx-units.h"
#include "chromatic_policy.h"
#include "density_generator.h"
#include "terminator.h"
#include "grain_model.h"
#include "biniostream.h"
#include "boost/foreach.hpp"
#include "boost/pointee.hpp"

namespace mcrx {
  // Forward declarations of the class templates used by the
  // convenience typedefs below.
  template <typename grid_type,
	    template <typename> class sampling_policy,
	    typename rng_policy> class ir_grid;
  template <typename T_cell_data> class temp_thread;
  template <typename T_cell_data> class arepo_grid;

  /// ir_grid built on an adaptive_grid whose cells carry polychromatic
  /// emitters (local_random RNG) and array_1 absorbers, using
  /// cumulative sampling.
  typedef ir_grid<adaptive_grid<cell_data<
		    typename emitter_types<adaptive_grid,
					   polychromatic_policy,
					   local_random>::T_emitter,
		    absorber<array_1> > >,
		  cumulative_sampling,
		  local_random> T_ir_adaptive_grid;

#ifdef WITH_AREPO
  /// Same as T_ir_adaptive_grid, but built on an arepo_grid. Only
  /// available when compiled WITH_AREPO.
  typedef ir_grid<arepo_grid<cell_data<
		    typename emitter_types<arepo_grid,
					   polychromatic_policy,
					   local_random>::T_emitter,
		    absorber<array_1> > >,
		  cumulative_sampling,
		  local_random> T_ir_arepo_grid;
#endif
}

/** A grid containing emitters of IR emission which can calculate the
    SEDs of the cells. It stores dust mass (NOT density) and cell
    luminosity.  This is also derived from emission, so it is its own
    emission object for the sources in the grid.

    \tparam grid_type The concrete grid implementation (for example
    adaptive_grid or arepo_grid, see the typedefs in namespace mcrx).
    The grid-interface functions at the end of the class forward to an
    instance of it.
    \tparam sampling_policy, rng_policy Passed through to the
    emission_collection base class. */
template <typename grid_type,
	  template<typename> class sampling_policy,
	  typename rng_policy> 
class mcrx::ir_grid : 
  public emission_collection <polychromatic_policy,
			      sampling_policy, rng_policy>
{ 
public:
  typedef grid_type T_grid_impl;
  typedef typename T_grid_impl::T_data T_data;
  typedef typename T_data::T_emitter T_emitter;
  typedef typename T_emitter::T_lambda T_lambda;

  typedef T_lambda T_content;
  typedef typename T_grid_impl::T_cell T_cell;
  typedef typename T_grid_impl::T_cell_tracker T_cell_tracker;
  typedef typename T_grid_impl::iterator iterator;
  typedef typename T_grid_impl::const_iterator const_iterator;
  
private:
  /// The emission SED wavelength vector
  T_content lambda_;
  /// Emission SED array indexed by (cell, wavelength). Note that
  /// normalize_for_sampling() rescales this array in place.
  array_2 L_lambda_;
  /// The normalization of L_lambda, needed to recover L_lambda during
  /// perturbative calculations. One entry per cell, filled by
  /// normalize_for_sampling() and released after a perturbative
  /// calculate_SED().
  array_1 L_emitted_;
  /// Previous emission SED for perturbative calculations, indexed by
  /// (cell, wavelength)
  array_2 L_lambda_old_;
  /// Dust mass, indexed by (cell, dust model)
  array_2 m_dust_;
  /// The integrated (net) luminosity in the grid, set by calculate_SED().
  T_float tot_lum_;
  /// The total emission weight in the previous iteration. Used for
  /// restoring L_lambda_old when we start a perturbative calculation.
  T_float prev_total_weight_;

  /// Vector of saved dust temperature arrays (one per dust model), if
  /// requested with save_temp.
  std::vector<array_2> temps_;

  /// The concrete grid object
  boost::shared_ptr<T_grid_impl> g_;

  /// Units
  T_unit_map units_;

  /** Allocates the L_lambda_ array and sets up the emitters that are
      the grid data objects. */
  void make_emitters();
  /// Creates an extension HDU in the given FITS file.
  /// NOTE(review): defined elsewhere; meaning of the string arguments
  /// is not visible here -- presumably used by write_seds/write_temps.
  CCfits::ExtHDU* create_hdu(CCfits::FITS&, const blitz::TinyVector<int,2>&,
			     const std::string&, const std::string&,
			     const std::string&, const std::string&) const;

public:
  /// Constructs from grid structure/data HDUs, a density generator and
  /// the emission wavelength vector (definition not visible here).
  ir_grid (boost::shared_ptr<T_grid_impl> g,
	   CCfits::ExtHDU& structure_hdu,
	   CCfits::ExtHDU& data_hdu,
	   const density_generator& dg,
	   const T_content& lambda);

  /// Constructs from an existing grid, a density generator and the
  /// emission wavelength vector (definition not visible here).
  ir_grid (boost::shared_ptr<T_grid_impl> g,
	   const density_generator& dg,
	   const T_content& lambda);

  /// Constructs from an existing grid, an array_2 (presumably the dust
  /// mass per (cell, model) -- TODO confirm), the wavelength vector and
  /// the unit map.
  ir_grid (boost::shared_ptr<T_grid_impl> g,
	   const array_2&, const T_content& lambda,
	   const T_unit_map& units);

  /// Calculates the cell SEDs from an intensity array; see definition.
  template<typename T_intensity, typename T_grain_model_ptr>
  bool calculate_SED(const blitz::ETBase<T_intensity>& intensity,
		     const std::vector<T_grain_model_ptr>& models,
		     const terminator& t,
		     int n_threads=1,
		     CCfits::ExtHDU* deposition_hdu=0,
		     bool perturbative=false,
		     bool sampling_normalization=false,
		     bool save_temp=false);

  /// Calculates the cell SEDs from an HDU of log10 intensities; see
  /// definition.
  template<typename T_grain_model_ptr>
  bool calculate_SED(CCfits::ExtHDU& intensity_hdu,
		     const std::vector<T_grain_model_ptr>& models,
		     const terminator& t,
		     int n_threads=1,
		     CCfits::ExtHDU* deposition_hdu=0,
		     bool perturbative=false,
		     bool sampling_normalization=false,
		     bool save_temp=false);

  /// Sets up the emission weights so emitters can be drawn; must not be
  /// used with direct ray integration.
  void normalize_for_sampling();

  /// Net integrated luminosity from the last calculate_SED().
  T_float total_luminosity() const {return tot_lum_; };

  /** Writes the SEDs of the cells to the specified FITS HDU. Note
      that if the emitters are normalized for sampling, this will undo
      that. */
  void write_seds(CCfits::FITS&, const std::string&, bool);

  /** Writes the temperatures of the cells to the specified FITS
      HDU. For this to work, calculate_SED must have been called with
      save_temp=true, otherwise the information is not retained. The
      specified HDU name is suffixed with a number indicating the dust
      species.  */
  void write_temps(CCfits::FITS&, const std::string&);

  /// Binary dump load/save of L_lambda_. Currently broken: both
  /// definitions begin with assert(0).
  void load_dump(binifstream&);
  void write_dump(binofstream&) const;

  /** Returns the conversion factor from internal area units (usually
      kpc) to SI (m^2). This is used when converting the intensity
      array to physical units.  The convert function will throw if
      there is no length unit, which is ok, since without units we
      can't make this conversion. */
  T_float area_factor() const {
    T_float to_m = units::convert(units_.get("length" ), "m");
    return to_m*to_m;
  };
  const T_unit_map& units() const {return units_;};    

  /** We must redefine this function here because the
  emission_collection version relies on the distribution being set up
  and that's not true unless normalize_for_sampling() has been run. */
  T_lambda zero_lambda() const {
    return T_content(lambda_*0.0); };

  /** Returns a heap-allocated copy of this object.
      NOTE(review): the copy shares g_ (shared_ptr), and blitz Array
      copy construction references rather than duplicates data, so the
      clone likely shares its SED arrays with the original -- confirm
      this is intended. */
  boost::shared_ptr<emission<polychromatic_policy, rng_policy> > clone() const {
    return boost::shared_ptr<emission<polychromatic_policy, rng_policy> >(new ir_grid(*this));};

  /// \name Forwarding functions for the grid interface.
  ///@{
  T_cell_tracker locate (const vec3d &p, int thread, bool accept_outside=false) {
    return g_->locate (p, thread, accept_outside);};
  std::pair<T_cell_tracker, T_float> 
  intersection_from_without (const ray_base& ray, int thread) const {
    return g_->intersection_from_without (ray, thread);};
  int get_cell_number(const T_cell_tracker& c) const {
    return g_->get_cell_number(c); };
  int n_cells() const { return g_->n_cells(); };

  const vec3d& getmin() const {return g_->getmin();};
  const vec3d& getmax() const {return g_->getmax();};

  const_iterator begin () const {return g_->begin ();};
  iterator begin () {return g_->begin ();};
  const_iterator end () const {return g_->end ();};
  iterator end () {return g_->end ();};
  ///@}
};


/**  This function performs the calculation of the dust emission SED
     in the grid cells based on the cell intensities in a FITS HDU.

     The HDU is expected to hold log10 of the intensities; they are
     exponentiated with pow(10., ...) before use. All remaining
     arguments are forwarded unchanged, in order, to the array-based
     calculate_SED overload, which does the actual work (see it for
     the meaning of each flag).

     \return The result of the array-based overload (the terminator's
     value). */
template <typename grid_type,
	  template<typename> class sampling_policy, 
	  typename rng_policy>
template<typename T_grain_model_ptr>
bool
mcrx::ir_grid<grid_type, sampling_policy, rng_policy>::
calculate_SED (CCfits::ExtHDU& intensity_hdu,
	       const std::vector<T_grain_model_ptr>& models,
	       const terminator& term,
	       int n_threads,
	       CCfits::ExtHDU* deposition_hdu,
	       bool perturbative,
	       bool sampling_normalization,
	       bool save_temp)
{
  // read the intensity HDU, which stores log10 values
  array_2 intensity;
  read (intensity_hdu, intensity);
  intensity = pow(10.,intensity);

  // make a weak reference to avoid locking
  array_2 intensity_ref;
  intensity_ref.weakReference(intensity);

  // Forward every flag explicitly and in declaration order: the
  // trailing bools all have defaults, so omitting one silently shifts
  // the following arguments into the wrong parameters.
  return calculate_SED(intensity_ref, models, term, n_threads, 
		       deposition_hdu, perturbative,
		       sampling_normalization, save_temp);
}


/**  This function performs the calculation of the dust emission SED
     in the grid cells based on the cell intensities in the supplied
     array, whose first dimension must have one entry per grid cell.

     L_lambda_ is zeroed and then each grain model ADDS its emission to
     it, using the corresponding column of the dust mass array
     m_dust_. Unit consistency between the grid and every model is
     checked first. Afterwards tot_lum_ is set to the wavelength
     integral of the (net) SED summed over all cells.

     \param intens Cell intensities, first dimension indexed by cell.
     \param models Grain models, one per column of m_dust_.
     \param term Terminator. Its value is returned, and sampling
     normalization is skipped if it is set.
     \param n_threads Currently not used by this function.
     \param deposition_hdu If non-null, the deposition HDU is read and
     converted from log10. NOTE(review): no energy-conservation check
     is performed with it here; the array is currently unused.
     \param perturbative If true, L_lambda_ is set to the difference
     between the newly calculated SED and the one from the previous
     iteration, which is kept in L_lambda_old_. On the first
     perturbative call L_lambda_old_ is reconstructed from L_lambda_,
     undoing the normalize_for_sampling() scaling if that was applied.
     If false, L_lambda_old_ is released.
     \param sampling_normalization If true (and the terminator is not
     set), normalize_for_sampling() is called after calculating the
     SEDs, as required if we are going to draw from the emitters.
     (This can also be called separately.)
     \param save_temp If true, temps_ is resized to one entry per model
     and each model fills its own temperature array, which
     write_temps() can then output.
     \return The terminator's value, term().  */
template <typename grid_type,
	  template<typename> class sampling_policy, 
	  typename rng_policy>
template<typename T_intensity, typename T_grain_model_ptr>
bool
mcrx::ir_grid<grid_type, sampling_policy, rng_policy>::
calculate_SED (const blitz::ETBase<T_intensity>& intens,
	       const std::vector<T_grain_model_ptr>& models,
	       const terminator& term,
	       int n_threads,
	       CCfits::ExtHDU* deposition_hdu,
	       bool perturbative,
	       bool sampling_normalization,
	       bool save_temp)
{
  using namespace std;
  using namespace blitz;
  
  const T_intensity& intensity(intens.unwrap());

  if (perturbative)
    std::cout << "Calculating update to grain temperatures and emission SEDs" << std::endl;
  else
    std::cout << "Calculating grain temperatures and emission SEDs" << std::endl;

  // check that units are consistent because we'd better be feeding
  // mass and intensity in the right units. This always warns about
  // inconsistent length units, which is a bit annoying and alarming.
  cout << "Don't worry about length unit warning that follows:\n";
  BOOST_FOREACH (const T_grain_model_ptr& m, models) {
    check_consistency(units(), m->units());
  }
  
  array_2 deposition;
  if(deposition_hdu) {
    read (*deposition_hdu, deposition);
    deposition = pow(10.,deposition);
  }

  // number of cells
  const int nc = this->n_cells();
  assert(nc==
	 intensity.ubound (firstDim)-intensity.lbound(firstDim)+1);

  // and if perturbative, allocate and reconstruct the old sed array, if needed
  if (perturbative && (L_lambda_old_.size()==0)) {
    std::cout << "Allocating a " << L_lambda_.shape()
	      << " memory block for previous cell emission, "
	      << 1.0*product(L_lambda_.shape())*sizeof(T_float)/(1024*1024*1024)
	      << " GB.\n";
    L_lambda_old_.resize(L_lambda_.shape());
    // then we need to copy the old SED to L_lambda_old. This is only
    // needed the first perturbative iteration, so after that it's not
    // needed anymore.
    std::cout << "Reconstructing previous cell emission." << endl;
    if (L_emitted_.size()==0) {
      // normalize_for_sampling() has not stored per-cell weights, so
      // L_lambda_ was not rescaled and can be copied as is. (Indexing
      // the empty L_emitted_ here would read out of bounds.)
      L_lambda_old_ = L_lambda_;
    }
    else if (prev_total_weight_ != 0) {
      // undo the normalization applied by normalize_for_sampling()
      L_lambda_old_ = L_lambda_(tensor::i, tensor::j)*
	L_emitted_(tensor::i)/prev_total_weight_;
    }
    else {
      // a zero total weight means every cell had zero emission; avoid
      // the 0/0 that would fill the array with NaN.
      L_lambda_old_ = 0;
    }
  }

  // if we aren't doing a perturbative calculation, we can throw away
  // the old sed
  if (!perturbative) {
    L_lambda_old_.free();
  }

  if (save_temp)
    temps_.resize(models.size());

  assert (nc==L_lambda_.extent(firstDim));

  L_lambda_=0;
  // models ADD their sed to L_lambda
  const int n_models = static_cast<int>(models.size());
  for (int j=0; j < n_models; ++j) {
    std::cout << "\t Calculating dust emission SED for model " << j << std::endl;
    array_1 md(m_dust_(Range::all(), j));
    models[j]->calculate_SED(intensity, md, L_lambda_, 
			     save_temp ? &(temps_[j]) : 0);
  }
  ASSERT_ALL(L_lambda_ == L_lambda_);
  ASSERT_ALL(L_lambda_ < 1e300);
  ASSERT_ALL(L_lambda_ >=0);
  
  std::cout << "Updating SED arrays" << std::endl;

  // If we are not doing a perturbative calculation, nothing needs to
  // be done here: L_lambda_old_ is reconstructed upon the first
  // perturbative calculation, which saves keeping that array around
  // for non-perturbative calculations.
  if (perturbative) {
    // Subtract and update the old SED. This requires a temporary, so
    // we do it on a cell by cell basis to save memory.
    array_1 tempLlambda(L_lambda_.extent(secondDim));
    threadLocal_warn(tempLlambda);
    
    for (int c=0; c<nc; ++c) {
      tempLlambda = L_lambda_(c, Range::all ());
      L_lambda_(c, Range::all ()) -= L_lambda_old_(c, Range::all ());
      L_lambda_old_(c, Range::all ()) = tempLlambda;
    }
  }

  // Calculate the total (net) luminosity: sum over cells for each
  // wavelength, then integrate over wavelength.
  array_1 L_lambda_tot(sum(L_lambda_(tensor::j, tensor::i),tensor::j));
  tot_lum_ = 
    integrate_quantity(L_lambda_tot, lambda_, false);

  if(!term() && sampling_normalization) 
    normalize_for_sampling();

  if(perturbative) {
    // If we are called as perturbative, there is no need to save
    // the L_emitted_ any longer, because it's only used on first
    // perturbative iteration.
    L_emitted_.free();
  }

  return term();
}


/**  Prepares the individual emitters for sampling by setting the
     emission_collection sample weights and normalizing the individual
     emitters accordingly. This must be called before drawing
     emitters, but must not be called if direct ray integration is
     used.

     Side effects:
     - each row of L_lambda_ is divided by its wavelength-integrated
       absolute luminosity (rows with zero luminosity are left alone),
       and then the whole array is multiplied by the total emission
       weight;
     - L_emitted_ is resized to n_cells() and holds the per-cell
       integrated absolute luminosities;
     - prev_total_weight_ is set to the distribution's total weight;
     - emission_update() is called on every emitter.  */
template <typename grid_type,
	  template<typename> class sampling_policy, 
	  typename rng_policy>
void
mcrx::ir_grid<grid_type, sampling_policy, rng_policy>::
normalize_for_sampling()
{
  using namespace blitz;
  const int nc = this->n_cells();

  std::cout << "Calculating emission weights" << std::endl;

  // This vector is used to collect the pointers to the emitters for
  // the call to distr_.setup.
  std::vector<emission<polychromatic_policy, rng_policy>*> pointers;
  pointers.reserve(nc);
  L_emitted_.resize(nc);

  int i = 0;
  for (typename T_grid_impl::iterator c = this->begin (); 
       c != this->end(); ++c, ++i) {
    // L_lambda_ and L_emitted_ have one row/entry per cell
    assert(i < nc);

    array_1 temp_lum;
    temp_lum.weakReference(L_lambda_(i, Range::all ()));
    assert (temp_lum.isStorageContiguous());

    // Calculate emission weight and normalize. We use the absolute
    // value, because in perturbative calculations L_lambda may be
    // negative or even change sign at some wavelength. NOTE that
    // this changes L_lambda_!
    const T_float L = integrate_quantity(abs(temp_lum), lambda_, false);
    if (L !=0) 
      // otherwise we get NaN, and there's no emission anyway so it
      // doesn't matter
      temp_lum/= L;
      
    DEBUG(1,std::cout << "Luminosity in cell " << i << '\t' << L << std::endl;);
      
    // Note that the emitters were already created in the
    // constructor. They are already referring to the temp_lum
    // array.
      
    // save variables for setup of emission_collection 
    pointers.push_back(&c->data()->get_emitter());
      
    // The emission weight is the bolometric dust luminosity
    L_emitted_(i) = L;
  }
    
  // Now set up the probability distribution object, with each cell's
  // bolometric dust luminosity as its emission weight.
  this->distr_.setup(pointers, L_emitted_);

  // The emission needs to be correctly normalized so that after
  // sampling the emitters, we get back the real luminosity. Because
  // we divided by L_emitted_ in each emitter, we need to now go
  // back and multiply in the total emission weight (sum of
  // L_emitted_ over all cells), analogously to what's done in
  // full_sed_particles_emission::emit.
  prev_total_weight_ = this->distr_.total_weight();
  L_lambda_ *= prev_total_weight_;

  // after updating the emission, we need to notify the emitters in
  // case they need to update their internal state
  BOOST_FOREACH(T_cell& c, *this) {
    c.data()->get_emitter().emission_update();
  }
}


/** Reads the cell SED array L_lambda_ back from a binary dump in the
    layout produced by write_dump(): the two array extents as ints,
    followed by the raw contiguous element data. CURRENTLY BROKEN, so
    the function trips an assertion on entry. Note that assertions
    compile away under NDEBUG, in which case the read below still
    runs. */
template <typename grid_type,
	  template<typename> class sampling_policy, 
	  typename rng_policy>
void 
mcrx::ir_grid<grid_type, sampling_policy, rng_policy>::load_dump (binifstream& file) 
{
  // Not functional -- stop here in debug builds.
  assert(0);
  std::cout << "Loading temperature calculation data from dump file."  << std::endl;

  // Header: the shape of the dumped array, which must match ours.
  blitz::TinyVector<int, 2> dumped_shape;
  file >> dumped_shape[0] >> dumped_shape[1];
  assert(all(L_lambda_.shape() == dumped_shape));
  DEBUG(1,std::cout << "Reading dumped cell dust L_lambda, " << dumped_shape << " elements" << std::endl;);

  // Payload: read straight into the array storage, which therefore
  // has to be contiguous.
  assert(L_lambda_.isStorageContiguous());
  char* const dest = reinterpret_cast<char*>(L_lambda_.dataFirst());
  file.read(dest, L_lambda_.size()*sizeof (array_2::T_numtype));
}


/** Writes the cell SED array L_lambda_ to a binary dump: the two
    array extents as ints, followed by the raw contiguous element
    data. CURRENTLY BROKEN, so the function trips an assertion on
    entry. Under NDEBUG the assertion compiles away and the write
    below still runs. */
template <typename grid_type,
	  template<typename> class sampling_policy, 
	  typename rng_policy>
void 
mcrx::ir_grid<grid_type, sampling_policy, rng_policy>::write_dump(binofstream& file) const
{
  // Not functional -- stop here in debug builds.
  assert(0);

  // Header: the array shape, so the reader can validate it.
  blitz::TinyVector<int, 2> shape_to_dump = L_lambda_.shape();
  file << shape_to_dump[0] << shape_to_dump[1];
  DEBUG(1,std::cout << "Dumping cell dust L_lambda, " << shape_to_dump << " elements" << std::endl;);

  // Payload: the raw element data, which requires contiguous storage.
  assert(L_lambda_.isStorageContiguous());
  const char* const src = reinterpret_cast<const char*>(L_lambda_.dataFirst());
  file.write(src, L_lambda_.size()*sizeof (array_2::T_numtype));
}
	    

#endif

  
